from .base_reasoner import BaseReasoner, ReasoningNode
import asyncio
import argparse
import json
import os
import re
import time
import traceback
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any, Union
import random
import aiohttp
from collections import defaultdict
from datetime import datetime
from .base_reasoner import LLMClient

class Gsm8kReasoner(BaseReasoner):
    def __init__(self):
        """Set up the reasoner for the "gsm8k" task and point it at the GSM8K JSON file."""
        super().__init__("gsm8k")
        config = self.config
        config.dataset_path = "datasets/gsm8k.json"
    
    async def load_problems(self, start_idx: int, end_idx: int) -> List[Dict]:
        """Load math problems from dataset"""
        try:
            with open(self.config.dataset_path, "r", encoding="utf-8") as f:
                data = json.load(f)
                return data[start_idx:end_idx]
        except Exception as e:
            print(f"Error loading dataset: {str(e)}")
            return []

    async def execute_workflow(self, question: str) -> Dict[str, Any]:
        """Execute full prompt engineering workflow with voting from three methods"""
        try:
            # Run all three methods in parallel
            ho1_result, ho2_result, cot_result = await asyncio.gather(
                self._execute_ho1_workflow(question),
                self._execute_ho2_workflow(question),
                self._execute_cot_workflow(question)
            )
            
            # Collect all answers
            answers = {
                "ho1": ho1_result.get("final_answer", None),
                "ho2": ho2_result.get("final_answer", None),
                "cot": cot_result.get("answer", None)
            }

            print("\nMethod Results:")
            print(f"HoT v1 Answer: {answers['ho1']}")
            print(f"HoT v2 Answer: {answers['ho2']}")
            print(f"CoT Answer: {answers['cot']}")

            # Voting
            final_answer = self._vote_on_answers(answers)
            print(f"\nFinal Voted Answer: {final_answer}")
            
            return {
                "status": "success",
                "final_answer": final_answer,
                "method_results": {
                    "ho1": ho1_result,
                    "ho2": ho2_result,
                    "cot": cot_result
                }
            }
            
        except Exception as e:
            traceback.print_exc()
            return {
                "status": "error",
                "message": str(e),
                "logs": self.logs
            }
    
    def _vote_on_answers(self, answers: Dict[str, Any]) -> str:
        """Voting mechanism to select the best answer"""
        # Count occurrences of each answer
        answer_counts = defaultdict(int)
        for method, answer in answers.items():
            if answer is not None:
                answer_counts[answer] += 1
        
        if not answer_counts:
            return "No valid answers found"
        
        # Find the answer with highest count
        max_count = max(answer_counts.values())
        candidates = [ans for ans, cnt in answer_counts.items() if cnt == max_count]
        
        if len(candidates) == 1:
            return candidates[0]
        else:
            if answers["ho1"] is not None:
                return answers["ho1"]
            elif answers["ho2"] is not None:
                return answers["ho2"]
            else:
                return answers["cot"] if answers["cot"] is not None else "No valid answers found"

    async def _execute_ho1_workflow(self, question: str) -> Dict[str, Any]:
        """Original HoT workflow (version 1).

        Steps: extract conditions (0), create the root node (1), explore
        candidate methods (2), create one node per method up to
        ``config.beam_width`` (3), ask whether the best-scoring method needs a
        case split (4), create one node per case (5), solve every leaf node (7)
        and aggregate the leaf answers (8).

        Returns:
            ``{"status": "success", "final_answer": <str>}`` on success, or
            ``{"status": "error", "message": <str>, "logs": self.logs}`` if any
            step raises.
        """
        try:
            llm = LLMClient()
            # Leaf nodes to solve in THIS run. execute_workflow runs this
            # coroutine concurrently with the v2 workflow; iterating the shared
            # ``self.temp_list`` made each workflow also solve (and aggregate)
            # the other's nodes and any left over from earlier questions.
            # ``self.temp_list`` is still appended to for any reader of it,
            # but only ``leaf_ids`` is iterated.
            leaf_ids: List[str] = []

            def as_list(value: Any) -> List[Any]:
                # Copy a condition list; tolerate a bare string/None from the LLM.
                if isinstance(value, list):
                    return list(value)
                return [] if value in (None, "") else [value]

            # Step 0: Extract problem conditions
            conditions = await self._extract_conditions(question, llm=llm)
            if not isinstance(conditions, dict):
                conditions = {
                    "explicit": [],
                    "implicit": [],
                    "notes": "Invalid conditions format"
                }
            self._log_step("step0", "system", {"conditions": conditions})

            # Step1: Create root node with initial conditions
            root = self._create_node(
                path=[],
                method={"description": "Original problem"},
                steps=[],
                conditions={
                    "explicit": conditions.get("explicit", []),
                    "implicit": conditions.get("implicit", [])
                },
                question=question
            )
            self._log_step("step1", root.node_id, {"question": question})

            # Step2: Explore solution methods
            methods = await self._explore_solutions(question, llm=llm)
            self._log_step("step2", root.node_id, {"methods": methods})

            # Step3: Create method nodes for the top beam_width methods
            method_nodes = []
            for method in methods[:self.config.beam_width]:
                node = self._create_node(
                    path=[root.node_id],
                    method=method,
                    steps=method.get("steps", []),
                    score=method.get("score", 0),
                    conditions=root.conditions,
                    parent_id=root.node_id
                )
                root.children.append(node.node_id)
                method_nodes.append(node)
                self._log_step("step3", node.node_id, {"method": method})

            # Step4: Check if classification needed. When exploration produced
            # no method, fall back to the root node instead of failing on
            # max() of an empty sequence.
            if method_nodes:
                best_method_node = max(method_nodes, key=lambda x: x.score)
            else:
                best_method_node = root
            classification_result = await self._check_classification(
                best_method_node.method["description"],
                best_method_node.steps,
                llm=llm
            )
            self._log_step("step4", best_method_node.node_id, classification_result)

            # Step5: Create classification nodes with combined conditions
            if classification_result.get("need_classify"):
                for case in classification_result.get("cases") or []:
                    if not isinstance(case, dict):
                        continue
                    case_conditions = case.get("conditions")
                    if not isinstance(case_conditions, dict):
                        case_conditions = {}

                    combined_conditions = {
                        "explicit": as_list(best_method_node.conditions.get("explicit")),
                        "implicit": as_list(best_method_node.conditions.get("implicit"))
                    }
                    for k, v in case_conditions.items():
                        if k in combined_conditions:
                            combined_conditions[k].append(v)
                        else:
                            combined_conditions["implicit"].append(f"{k}: {v}")

                    node = self._create_node(
                        path=best_method_node.path + [best_method_node.node_id],
                        method=best_method_node.method,
                        steps=best_method_node.steps,
                        score=best_method_node.score,
                        conditions=combined_conditions,
                        parent_id=best_method_node.node_id,
                        question=question
                    )
                    best_method_node.children.append(node.node_id)
                    leaf_ids.append(node.node_id)
                    self.temp_list.append(node.node_id)
                    self._log_step("step5", node.node_id, {
                        "case": case,
                        "combined_conditions": combined_conditions
                    })

            # No classification (or no usable case): solve the best node itself,
            # so a "need_classify" reply without cases does not solve nothing.
            if not leaf_ids:
                leaf_ids.append(best_method_node.node_id)
                self.temp_list.append(best_method_node.node_id)

            # Step6: Record this run's leaf list
            self._log_step("step6", "system", {"temp_list": list(leaf_ids)})

            # Step7: Solve nodes iteratively
            solutions = []
            for node_id in leaf_ids:
                solution = await self._solve_node(node_id, llm=llm)
                if solution:
                    solutions.append(solution)
                    self._log_step("step7", node_id, {"solution": solution})

            # Step8: Aggregate answers
            final_answer = await self._aggregate_answers(solutions, llm=llm)
            self._log_step("step8", "system", {"final_answer": final_answer})

            return {
                "status": "success",
                "final_answer": final_answer
            }

        except Exception as e:
            traceback.print_exc()
            return {
                "status": "error",
                "message": str(e),
                "logs": self.logs
            }

    async def _execute_ho2_workflow(self, question: str) -> Dict[str, Any]:
        """Simplified HoT workflow (version 2).

        Skips method exploration: extracts conditions, creates the root node,
        asks whether the problem itself needs a case split, solves each case
        node (or the root when there is no split) and aggregates the answers.

        Returns:
            ``{"status": "success", "final_answer": <str>}`` or, if any step
            raises, ``{"status": "error", "message": <str>}``.
        """
        try:
            llm = LLMClient()
            # Leaf nodes to solve in THIS run. execute_workflow runs this
            # coroutine concurrently with the v1 workflow; iterating the shared
            # ``self.temp_list`` made each workflow also solve (and aggregate)
            # the other's nodes and any left over from earlier questions.
            # ``self.temp_list`` is still appended to for any reader of it,
            # but only ``leaf_ids`` is iterated.
            leaf_ids: List[str] = []

            def as_list(value: Any) -> List[Any]:
                # Copy a condition list; tolerate a bare string/None from the LLM.
                if isinstance(value, list):
                    return list(value)
                return [] if value in (None, "") else [value]

            # Step 0: Extract problem conditions
            conditions = await self._extract_conditions(question, llm=llm)
            if not isinstance(conditions, dict):
                conditions = {
                    "explicit": [],
                    "implicit": [],
                    "notes": "Invalid conditions format"
                }

            # Step1: Create root node with initial conditions
            root = self._create_node(
                path=[],
                method={"description": "Original problem"},
                steps=[],
                conditions={
                    "explicit": conditions.get("explicit", []),
                    "implicit": conditions.get("implicit", [])
                },
                question=question
            )

            # Step4: Check if classification needed. There is no explored
            # method here, so the problem text is included: with only the
            # description "Original problem" and an empty step list the model
            # had nothing to judge.
            classification_result = await self._check_classification(
                f"{root.method['description']}: {question}",
                root.steps,
                llm=llm
            )

            # Step5: Create classification nodes with combined conditions
            if classification_result.get("need_classify"):
                for case in classification_result.get("cases") or []:
                    if not isinstance(case, dict):
                        continue
                    case_conditions = case.get("conditions")
                    if not isinstance(case_conditions, dict):
                        case_conditions = {}

                    combined_conditions = {
                        "explicit": as_list(root.conditions.get("explicit")),
                        "implicit": as_list(root.conditions.get("implicit"))
                    }
                    for k, v in case_conditions.items():
                        if k in combined_conditions:
                            combined_conditions[k].append(v)
                        else:
                            combined_conditions["implicit"].append(f"{k}: {v}")

                    node = self._create_node(
                        path=[root.node_id],
                        method=root.method,
                        steps=root.steps,
                        score=0,
                        conditions=combined_conditions,
                        parent_id=root.node_id,
                        question=question
                    )
                    root.children.append(node.node_id)
                    leaf_ids.append(node.node_id)
                    self.temp_list.append(node.node_id)

            # No classification (or no usable case): solve the root directly.
            if not leaf_ids:
                leaf_ids.append(root.node_id)
                self.temp_list.append(root.node_id)

            # Step7: Solve nodes iteratively
            solutions = []
            for node_id in leaf_ids:
                solution = await self._solve_node(node_id, llm=llm)
                if solution:
                    solutions.append(solution)

            # Step8: Aggregate answers
            final_answer = await self._aggregate_answers(solutions, llm=llm)

            return {
                "status": "success",
                "final_answer": final_answer
            }

        except Exception as e:
            traceback.print_exc()
            return {
                "status": "error",
                "message": str(e)
            }

    async def _execute_cot_workflow(self, question: str) -> Dict[str, Any]:
        """Chain-of-Thought workflow: one step-by-step LLM call.

        Returns:
            ``{"status": "success", "response": <raw reply>, "answer": <str | None>}``;
            the ``status`` key matches the other workflows' result shape. On any
            exception: ``{"status": "error", "message": <str>, "answer": None}``.
        """
        try:
            llm = LLMClient()

            prompt = f"""
Problem: {question}
Let's think step by step, the final answer should be one exact number. provide the final answer in the format "Final Answer: your answer".
"""

            response = await llm.generate(prompt)
            answer = self._extract_answer(response)

            return {
                "status": "success",
                "response": response,
                "answer": answer
            }

        except Exception as e:
            print(f"CoT Error: {str(e)}")
            return {
                "status": "error",
                "message": str(e),
                "answer": None
            }

    async def _extract_conditions(self, question: str, llm: LLMClient) -> Dict:
        """Extract explicit and implicit conditions from problem"""
        prompt = f"""You are a world-class mathematician and mathematical logician.  
    You are intelligent, rigorous, and cautious.  
    You always reason step by step, consider all relevant conditions.  
    You think in terms of structure, symmetry, and mathematical principles, and never skip important logical steps.  
    You aim to find a complete and correct solution, not just an answer.  
    You THINK CLEARLY, STRUCTURALLY, AND DEEPLY. 
    Analyze this math problem and extract ALL conditions:
    
    problem:{question}
    
    Notice:
    1. Identify explicit conditions (directly stated in the problem)
    2. Derive implicit conditions (e.g., denominators ≠ 0, square roots ≥ 0, log arguments > 0)
    3. Determine domain restrictions based on mathematical principles
    4. Identify range limitations from problem context
    5. Extract physical meaning conditions (e.g., length > 0, probability ∈ [0,1])
    
    Output JSON format:
    {{
        "explicit": ["constraint1", "constraint2"],
        "implicit": ["constraint1", "constraint2"],
        "notes": "Additional analysis notes"
    }}"""
    
        for attempt in range(self.config.max_retries):
            try:
                response = await llm.generate(prompt, response_format="json_object")
                data = json.loads(response)
                
                if not isinstance(data, dict):
                    print(f"Invalid response type (attempt {attempt+1}): {type(data)}")
                    continue
                    
                conditions = {
                    "explicit": data.get("explicit", []),
                    "implicit": data.get("implicit", []),
                    "notes": data.get("notes", "")
                }
                
                if not (conditions["explicit"] or conditions["implicit"]):
                    print(f"Empty conditions (attempt {attempt+1})")
                    continue
                    
                return conditions
                
            except (json.JSONDecodeError, AttributeError) as e:
                print(f"Parse error (attempt {attempt+1}): {str(e)}")
                continue
        
        print("All retries failed, returning default conditions")
        return {
            "explicit": ["Default explicit constraint"],
            "implicit": ["Default implicit constraint"],
            "notes": "Fallback conditions"
        }
    
    async def _explore_solutions(self, question: str, llm: LLMClient) -> List[Dict]:
        """Step2: Explore diverse solution methods"""
        prompt = f"""You are a world-class mathematician and mathematical logician.  
    You are intelligent, rigorous, and cautious.  
    You always reason step by step, consider all relevant conditions.  
    You think in terms of structure, symmetry, and mathematical principles, and never skip important logical steps.  
    You aim to find a complete and correct solution, not just an answer.  
    You THINK CLEARLY, STRUCTURALLY, AND DEEPLY. 
    Generate 3 distinct solution methods for:

{question}

Notice:
1. Employ different theoretical frameworks (algebraic, geometric, analytical, etc.)
2. Approach from fundamentally different perspectives
3. Vary implementation techniques significantly
4. Consider both conventional and innovative methods
5. Steps can be retained as ideas only, without exact calculations
6. Pay attention to the mathematical expressions in the questions and understand them correctly
7. examine carefully the subject matter

For each method, provide:
- Clear description of the mathematical approach
- Step-by-step implementation plan
- Effectiveness score (0-100) based on:
  * Mathematical rigor
  * Computational feasibility
  * Logical completeness
  * Efficiency

Output JSON format:
{{
    "methods": [
        {{
            "description": "Method description",
            "steps": ["step1", "step2"],
            "score": 0-100,
            "score_reason": "Scoring justification"
        }}
    ]
}}"""
        
        for attempt in range(self.config.max_retries):
            try:
                response = await llm.generate(prompt, response_format="json_object")
                response = response.strip()
                
                if response.startswith("```json"):
                    response = response[7:-3].strip()
                elif response.startswith("```"):
                    response = response[3:-3].strip()
                
                response = response.replace('\\', '\\\\')
                
                data = json.loads(response)
                
                if not isinstance(data, dict) or "methods" not in data:
                    raise ValueError("Invalid structure: missing 'methods' key")
                    
                methods = data["methods"]
                if len(methods) != 3:
                    raise ValueError(f"Expected 3 methods, got {len(methods)}")
                    
                required_keys = {"description", "steps", "score", "score_reason"}
                for method in methods:
                    if not all(k in method for k in required_keys):
                        raise ValueError("Missing required keys in method")
                    if not isinstance(method["steps"], list):
                        raise ValueError("Steps must be a list")
                        
                return sorted(methods, key=lambda x: -x["score"])
                
            except (json.JSONDecodeError, ValueError, KeyError) as e:
                print(f"Attempt {attempt + 1} failed: {str(e)}")
                if attempt == self.config.max_retries - 1:
                    print(f"Final failed response: {response}")
                    return []
                continue
                
        return [] 
    
    async def _check_classification(self, method: str, steps: List[str], llm: LLMClient) -> Dict[str, Any]:
        """Step4: Determine if classification needed"""
        prompt = f"""You are a world-class mathematician and mathematical logician.  
    You are intelligent, rigorous, and cautious.  
    You always reason step by step, consider all relevant conditions.  
    You think in terms of structure, symmetry, and mathematical principles, and never skip important logical steps.  
    You aim to find a complete and correct solution, not just an answer.  
    You THINK CLEARLY, STRUCTURALLY, AND DEEPLY. 
    Determine if this solution requires classification:

Method: {method}
Steps: {steps}

Notice:
1. Identify parameter dependencies requiring discussion
2. Detect interval-specific elements (absolute values, piecewise functions)
3. Recognize domain-specific computation requirements
4. Flag multiple solution sets needing verification
5. Pay attention to the mathematical expressions in the questions and understand them correctly
6. examine carefully the subject matter

If classification needed, provide:
- Comprehensive case descriptions
- Precise mathematical conditions for each case
- Clear boundary conditions

Output JSON format:
{{
    "need_classify": true/false,
    "reason": "Classification rationale",
    "cases": [
        {{
            "description": "Case description",
            "conditions": {{"parameter": "value_range"}}
        }}
    ]
}}"""
        
        response = await llm.generate(prompt, response_format="json_object")
        try:
            data = json.loads(response)
            return {
                "need_classify": data.get("need_classify", False),
                "reason": data.get("reason", ""),
                "cases": data.get("cases", [])
            }
        except json.JSONDecodeError:
            print(f"Failed to parse classification response: {response}")
            return {"need_classify": False, "reason": "Parse failed", "cases": []}
    
    async def _solve_node(self, node_id: str, llm: LLMClient) -> Optional[Dict[str, Any]]:
        """Step7: Solve individual node"""
        node = self.nodes[node_id]
        root_node = self.nodes[node.path[0]] if node.path else node
        original_question = getattr(node, 'question', None) or getattr(root_node, 'question', "Original problem")
        
        prompt = f"""You are a world-class mathematician and mathematical logician.  
    You are intelligent, rigorous, and cautious.  
    You always reason step by step, consider all relevant conditions.  
    You think in terms of structure, symmetry, and mathematical principles, and never skip important logical steps.  
    You aim to find a complete and correct solution, not just an answer.  
    You THINK CLEARLY, STRUCTURALLY, AND DEEPLY. 
    You are a meticulous mathematical problem solver executing this solution:
    
    Original Problem: {original_question}
    Steps: {node.steps}
    conditions: {node.conditions}
    
    As an executor, you must:
    1. Follow the provided steps precisely
    2. Explicitly verify all conditions at each step
    3. Show complete mathematical justification
    4. Use proper mathematical notation
    5. Clearly mark the final answer with \\boxed{{}}
    6. Include standalone line: "Final Answer: answer"
    7. Ensure your answer directly responds to the question asked
    8. The final answer should be one exact number
    9. Pay attention to the mathematical expressions in the questions and understand them correctly
    10. examine carefully the subject matter
    
    Additional requirements:
    - Show all intermediate calculations
    - State any assumptions made
    - Verify solution satisfies all conditions
    - Cross-validate critical steps
    - If the question asks for GCD, provide only the GCD as final answer
    - If you calculate intermediate values (like A and B), clearly distinguish them from the final answer"""
        
        response = await llm.generate(prompt)
        answer = self._extract_answer(response)
        
        if answer:
            node.answer = answer
            node.state = "solved"
            return {
                "node_id": node_id,
                "response": response,
                "answer": answer
            }
        return None

    
    async def _aggregate_answers(self, solutions: List[Dict[str, Any]], llm: LLMClient) -> str:
        """Step8: Aggregate solutions with original question"""
        if not solutions:
            return "No valid solutions found"
        
        original_question = None
        for sol in solutions:
            node = self.nodes[sol["node_id"]]
            if hasattr(node, 'original_question'):
                original_question = node.original_question
                break
        
        if original_question is None:
            first_node = self.nodes[solutions[0]["node_id"]]
            path = first_node.path
            if path: 
                root_node_id = path[0]
                root_node = self.nodes.get(root_node_id)
                if root_node:
                    original_question = root_node.method.get("description", "Original problem")
        
        if original_question is None:
            original_question = "Original problem (reconstructed from context)"
            if solutions[0]["response"]:
                match = re.search(r'Original Problem[:\s]*(.+?)\nSteps:', solutions[0]["response"])
                if match:
                    original_question = match.group(1).strip()
        
        if len(solutions) == 1:
            return solutions[0]["answer"]
        
        unique_answers = {sol["answer"] for sol in solutions}
        if len(unique_answers) == 1:
            return solutions[0]["answer"]
        
        solutions_text = "\n\n".join(
            f"Solution {i+1} (Node: {sol['node_id']}):\n"
            f"Answer: {sol['answer']}\n"
            f"Approach: {self.nodes[sol['node_id']].method['description']}\n"
            f"conditions: {self.nodes[sol['node_id']].conditions}\n"
            f"Reasoning Excerpt:\n{sol['response'][:300]}...\n"
            for i, sol in enumerate(solutions)
        )
        
        prompt = f"""You are a world-class mathematician and mathematical logician.  
    You are intelligent, rigorous, and cautious.  
    You always reason step by step, consider all relevant conditions.  
    You think in terms of structure, symmetry, and mathematical principles, and never skip important logical steps.  
    You aim to find a complete and correct solution, not just an answer.  
    You THINK CLEARLY, STRUCTURALLY, AND DEEPLY. 
    Synthesize these solutions for the original problem:
    
    Original Problem: {original_question}
    
    Proposed Solutions:
    {solutions_text}
    
    As an analyst, you must:
    1. FIRST verify which solution(s) correctly answer the original question
    2. Compare mathematical consistency with the original problem statement
    3. Evaluate which approach best satisfies all conditions
    4. Combine elements from multiple solutions ONLY if mathematically valid
    5. Provide clear justification for your selection
    6. Mark final answer with \\boxed{{}}
    7. Include standalone line: "Aggregated Answer: answer"
    
    Critical Analysis Guidelines:
    - The solution MUST directly answer the original question as stated
    - Prioritize mathematical correctness over elegance
    - Reject solutions that violate any explicit conditions
    - Verify all intermediate calculations are sound
    - Ensure the final answer format matches what the problem requires"""
    
        response = await llm.generate(prompt)
        return self._extract_answer(response) or "Aggregation failed"
    
    def _extract_answer(self, text: str) -> Optional[str]:
        """Extract answer from response text"""
        boxed_pattern = r'\\boxed\{([^{}]+)\}'
        boxed_matches = re.findall(boxed_pattern, text)
        if boxed_matches:
            return boxed_matches[-1] 
    
        final_answer_match = re.search(
            r'Final\s+Answer\s*:\s*([^\n]+)', 
            text, 
            re.IGNORECASE
        )
        if final_answer_match:
            return final_answer_match.group(1).strip()
    
        return None

    def save_results(self, result: Dict[str, Any], problem: Dict[str, Any]) -> Dict[str, Any]:
        """Save simplified results without detailed process information"""
        verification = {
            "is_correct": False,
            "correct_answer": None,
            "given_answer": result.get("final_answer")
        }
        
        if "answer" in problem:
            correct_answer = None
            if "solution" in problem:
                correct_answer = self._extract_correct_answer(problem["solution"])
            elif "answer" in problem:
                correct_answer = self._extract_correct_answer(problem["answer"])

            verification["correct_answer"] = correct_answer
            
            if correct_answer is not None and "final_answer" in result:
                given = str(result["final_answer"]).strip()
                expected = str(correct_answer).strip()
                
                if len(expected) == 1 and given.endswith(expected):
                    verification["is_correct"] = True
                else:
                    try:
                        given_num = float(given)
                        expected_num = float(expected)
                        if abs(given_num - expected_num) < 1e-10:  
                            verification["is_correct"] = True
                    except ValueError:
                        pass 
        
        return {
            "question": problem["question"],
            "final_answer": result["final_answer"],
            "method_answers": result.get("method_answers", {}),
            "verification": verification
        }
    
    def _extract_correct_answer(self, solution: str) -> Optional[str]:
        """Extract correct answer from solution's \boxed{}"""
        hash_pattern = r'####\s*([^\n]+)'
        hash_matches = re.findall(hash_pattern, solution)
        return hash_matches[-1].strip() if hash_matches else None
    
    async def verify_answer(self, problem: Dict[str, Any], final_answer: str) -> bool:
        """Verify if final answer matches correct solution"""
        if "solution" not in problem:
            return False
            
        correct_answer = self._extract_correct_answer(problem["solution"])
        if not correct_answer:
            return False
            
        return str(final_answer).strip() == str(correct_answer).strip()